import sklearn.datasets as dataset
from sklearn.cluster import KMeans
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Load the per-country results table. The file is tab-separated with no
# header row, so the four columns are named here: country, 2006 World Cup,
# 2010 World Cup, 2007 Asian Cup.
data = pd.read_csv('asiazoo.txt', header=None, delimiter='\t')
data.columns = ["国家", "2006世界杯", "2010世界杯", "2007亚洲杯"]

# Cluster the countries on their three tournament results.
# n_init is pinned explicitly: scikit-learn warns about its default in
# 1.2/1.3 and changed it from 10 to 'auto' in 1.4.
kmn = KMeans(n_clusters=3, n_init=10)
x_train = data[["2006世界杯", "2010世界杯", "2007亚洲杯"]]
# One cluster label (0 .. n_clusters-1) per row of x_train.
y_new = kmn.fit_predict(x_train)
print(y_new)
print(data.loc[0, "国家"])

# Print the members of each cluster. The labels are already in y_new, so
# predict() is not re-run; the boolean mask keeps the original row index,
# and the selected country names are printed as a Series.
d = data["国家"]
for i in range(kmn.n_clusters):
    print(d[y_new == i])
    print('\n')

# Plot every country in 3-D (one axis per tournament column), coloured by
# its cluster, with the cluster centres overlaid, and save to 3d.png.
centers = kmn.cluster_centers_
x, y, z = data['2006世界杯'], data['2010世界杯'], data['2007亚洲杯']
fig = plt.figure(figsize=(12, 8))
# Figure.gca() stopped accepting keyword arguments in Matplotlib 3.6;
# add_subplot is the supported way to request a 3-D Axes.
ax = fig.add_subplot(111, projection='3d')

n_clusters = len(centers)
labels = kmn.predict(x_train)
# Both scatters share one colormap and colour range so that each centre is
# drawn in the same colour as the points of its cluster.
color_kw = dict(cmap='cool', vmin=0, vmax=max(n_clusters - 1, 1))
ax.scatter(x, y, z, c=labels, **color_kw)
# Row k of cluster_centers_ is the centre of cluster k, so its label is k.
# Using arange avoids predict() on a bare ndarray, which warns about missing
# feature names when the model was fit on a DataFrame.
ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2],
           c=np.arange(n_clusters), marker='x', s=200, **color_kw)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.savefig('3d.png')
plt.close(fig)